"""
nuScenes雷达点云多传感器可视化脚本
功能：在3D空间中显示同一场景下五个雷达传感器的点云
环境要求：Python 3.6+，需安装以下包：
  pip install nuscenes-devkit matplotlib numpy pyquaternion
"""

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from nuscenes import NuScenes
from pyquaternion import Quaternion
from nuscenes.utils.data_classes import LidarPointCloud, RadarPointCloud

# 配置参数
DATASET_ROOT = "./mini"  # 必须修改为实际数据集路径
SCENE_INDEX = 0                                # 场景索引（0表示第一个场景）
MAX_DISTANCE = 100.0                           # 最大可视距离（米）
POINT_SIZE = 2                                 # 点云大小
ALPHA = 0.6                                    # 点云透明度

# 雷达传感器列表（nuScenes mini包含的五个雷达）
RADAR_NAMES = [
    "RADAR_FRONT",
    "RADAR_FRONT_LEFT",
    "RADAR_FRONT_RIGHT",
    "RADAR_BACK_LEFT",
    "RADAR_BACK_RIGHT",
]

# 颜色映射（每个雷达对应不同颜色）
COLOR_MAP = {
    "RADAR_FRONT": "red",
}
    # "RADAR_FRONT_LEFT": "green",
    # "RADAR_FRONT_RIGHT": "blue",
    # "RADAR_BACK_LEFT": "cyan",
    # "RADAR_BACK_RIGHT": "magenta",

def load_radar_data(nusc, sample_token, radar_name):
    """加载并转换指定雷达的点云数据到全局坐标系"""
    # 获取雷达数据记录
    radar_data = nusc.get("sample_data", sample_token["data"][radar_name])
    
    # 加载点云二进制文件
    file_path = os.path.join(nusc.dataroot, radar_data["filename"])
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"雷达文件不存在：{file_path}")
    pc = RadarPointCloud.from_file(file_path)
    #pc = np.fromfile(file_path, dtype=np.float32).reshape((-1, 7)).T
    pc = pc.points.T  # (4, num_points) 或 (7, num_points)，取决于雷达类型

    # # 坐标系转换：雷达坐标系 -> 车辆坐标系
    # calib = nusc.get("calibrated_sensor", radar_data["calibrated_sensor_token"])
    # points = np.vstack((pc[:3, :], np.ones(pc.shape[1])))  # 齐次坐标
    # points_vehicle = calib["rotation"].rotation_matrix @ points[:3, :] + np.array(calib["translation"]).reshape(-1, 1)

    # # 坐标系转换：车辆坐标系 -> 全局坐标系
    # ego_pose = nusc.get("ego_pose", radar_data["ego_pose_token"])
    # points_global = ego_pose["rotation"].rotation_matrix @ points_vehicle + np.array(ego_pose["translation"]).reshape(-1, 1)

    return pc[:, :3]  # (num_points, 3)

def setup_3d_plot():
    """配置3D可视化画布"""
    fig = plt.figure(figsize=(16, 10))
    ax = fig.add_subplot(111, projection="3d")
    ax.view_init(elev=30, azim=-90)  # 俯视视角
    ax.set_xlabel("X (m)", fontsize=10)
    ax.set_ylabel("Y (m)", fontsize=10)
    ax.set_zlabel("Z (m)", fontsize=10)
    return fig, ax

def plot_radar_points(ax, points, color, label):
    """绘制单个雷达的点云"""
    if points.size == 0:
        return
    # 应用距离过滤
    distances = np.linalg.norm(points, axis=1)
    valid_points = points[distances < MAX_DISTANCE]
    # 绘制点云
    ax.scatter(
        valid_points[:, 0],
        valid_points[:, 1],
        valid_points[:, 2],
        c=color,
        s=POINT_SIZE,
        label=label,
        alpha=ALPHA,
        depthshade=False,
    )

def main():
    # 检查数据集路径
    if not os.path.exists(DATASET_ROOT):
        raise FileNotFoundError(f"数据集路径不存在：{DATASET_ROOT}")

    # 初始化数据集
    print(f"\n正在加载nuScenes mini数据集...")
    nusc = NuScenes(version="v1.0-mini", dataroot=DATASET_ROOT, verbose=False)
    
    # 验证场景索引
    if SCENE_INDEX >= len(nusc.scene):
        raise ValueError(f"无效场景索引：{SCENE_INDEX}，最大允许值：{len(nusc.scene)-1}")
    
    # 获取场景信息
    scene = nusc.scene[SCENE_INDEX]
    sample = nusc.get("sample", scene["first_sample_token"])
    print(f"\n可视化场景：{scene['name']}")

    # 准备可视化
    fig, ax = setup_3d_plot()
    #title = f"Radar Point Cloud Visualization\nScene: {scene['name']} ({scene['location']})"
    title = f"Radar Point Cloud Visualization\nScene: {scene['name']}"
    plt.title(title, fontsize=12)

    # 加载并绘制所有雷达数据
    print("\n正在处理雷达数据：")
    for radar_name in RADAR_NAMES:
        print(f"  - 处理 {radar_name}...", end="", flush=True)
        try:
            points = load_radar_data(nusc, sample, radar_name)
            plot_radar_points(ax, points, COLOR_MAP[radar_name], radar_name)
            print("完成")
        except Exception as e:
            print(f"\n  错误：处理{radar_name}时发生异常 - {str(e)}")
            continue

    # 添加图例和调整布局
    ax.legend(loc="upper right", fontsize=8)
    plt.tight_layout()
    
    # 显示可视化结果
    print("\n正在生成可视化图表...")
    plt.show()

if __name__ == "__main__":
    main()